use tract_data::half::f16;

// Element-wise `a[i] *= b[i]` over f16 slices using AArch64 FP16 vector `fmul`,
// 32 lanes (four 8h registers) per loop iteration.
unicast_impl_wrap!(
    f16,
    arm64fp16_unicast_mul_f16_32n,
    32,
    8,
    #[inline(never)]
    fn run(a: &mut [f16], b: &[f16]) {
        // The asm loop consumes exactly 32 lanes per iteration and only tests the
        // counter after decrementing it, so it requires a non-empty multiple of 32.
        assert!(a.len() == b.len());
        assert!(a.len() % 32 == 0);
        assert!(a.len() > 0);
        #[target_feature(enable = "fp16")]
        unsafe fn run(a: &mut [f16], b: &[f16]) {
            unsafe {
                let len = a.len();
                // The asm stores through this pointer, so it must carry write
                // provenance: take it from `as_mut_ptr`, not `as_ptr`.
                let a_ptr = a.as_mut_ptr();
                let b_ptr = b.as_ptr();
                std::arch::asm!("
                2:
                    ld1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}]
                    ld1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{b_ptr}], 64
                    fmul v0.8h, v0.8h, v4.8h
                    fmul v1.8h, v1.8h, v5.8h
                    fmul v2.8h, v2.8h, v6.8h
                    fmul v3.8h, v3.8h, v7.8h
                    st1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}], 64
                    subs {len}, {len}, 32
                    bne 2b
            ",
            len = inout(reg) len => _,
            a_ptr = inout(reg) a_ptr => _,
            b_ptr = inout(reg) b_ptr => _,
            // v0-v3 receive `a`, v4-v7 receive `b`: all eight are overwritten.
            out("v0") _, out("v1") _, out("v2") _, out("v3") _,
            out("v4") _, out("v5") _, out("v6") _, out("v7") _,
            options(nostack));
            }
        }
        unsafe { run(a, b) }
    }
);

// Element-wise `a[i] += b[i]` over f16 slices using AArch64 FP16 vector `fadd`,
// 32 lanes (four 8h registers) per loop iteration.
unicast_impl_wrap!(
    f16,
    arm64fp16_unicast_add_f16_32n,
    32,
    8,
    #[inline(never)]
    fn run(a: &mut [f16], b: &[f16]) {
        // The asm loop consumes exactly 32 lanes per iteration and only tests the
        // counter after decrementing it, so it requires a non-empty multiple of 32.
        assert!(a.len() == b.len());
        assert!(a.len() % 32 == 0);
        assert!(a.len() > 0);
        #[target_feature(enable = "fp16")]
        unsafe fn run(a: &mut [f16], b: &[f16]) {
            unsafe {
                let len = a.len();
                // The asm stores through this pointer, so it must carry write
                // provenance: take it from `as_mut_ptr`, not `as_ptr`.
                let a_ptr = a.as_mut_ptr();
                let b_ptr = b.as_ptr();
                std::arch::asm!("
                2:
                    ld1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}]
                    ld1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{b_ptr}], 64
                    fadd v0.8h, v0.8h, v4.8h
                    fadd v1.8h, v1.8h, v5.8h
                    fadd v2.8h, v2.8h, v6.8h
                    fadd v3.8h, v3.8h, v7.8h
                    st1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}], 64
                    subs {len}, {len}, 32
                    bne 2b
            ",
            len = inout(reg) len => _,
            a_ptr = inout(reg) a_ptr => _,
            b_ptr = inout(reg) b_ptr => _,
            // v0-v3 receive `a`, v4-v7 receive `b`: all eight are overwritten.
            out("v0") _, out("v1") _, out("v2") _, out("v3") _,
            out("v4") _, out("v5") _, out("v6") _, out("v7") _,
            options(nostack));
            }
        }
        unsafe { run(a, b) }
    }
);

// Element-wise `a[i] = a[i] - b[i]` over f16 slices using AArch64 FP16 vector
// `fsub`, 32 lanes (four 8h registers) per loop iteration.
unicast_impl_wrap!(
    f16,
    arm64fp16_unicast_sub_f16_32n,
    32,
    8,
    #[inline(never)]
    fn run(a: &mut [f16], b: &[f16]) {
        // The asm loop consumes exactly 32 lanes per iteration and only tests the
        // counter after decrementing it, so it requires a non-empty multiple of 32.
        assert!(a.len() == b.len());
        assert!(a.len() % 32 == 0);
        assert!(a.len() > 0);
        #[target_feature(enable = "fp16")]
        unsafe fn run(a: &mut [f16], b: &[f16]) {
            unsafe {
                let len = a.len();
                // The asm stores through this pointer, so it must carry write
                // provenance: take it from `as_mut_ptr`, not `as_ptr`.
                let a_ptr = a.as_mut_ptr();
                let b_ptr = b.as_ptr();
                std::arch::asm!("
                2:
                    ld1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}]
                    ld1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{b_ptr}], 64
                    fsub v0.8h, v0.8h, v4.8h
                    fsub v1.8h, v1.8h, v5.8h
                    fsub v2.8h, v2.8h, v6.8h
                    fsub v3.8h, v3.8h, v7.8h
                    st1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}], 64
                    subs {len}, {len}, 32
                    bne 2b
            ",
            len = inout(reg) len => _,
            a_ptr = inout(reg) a_ptr => _,
            b_ptr = inout(reg) b_ptr => _,
            // v0-v3 receive `a`, v4-v7 receive `b`: all eight are overwritten.
            out("v0") _, out("v1") _, out("v2") _, out("v3") _,
            out("v4") _, out("v5") _, out("v6") _, out("v7") _,
            options(nostack));
            }
        }
        unsafe { run(a, b) }
    }
);

// Element-wise reversed subtraction `a[i] = b[i] - a[i]` over f16 slices using
// AArch64 FP16 vector `fsub` with swapped operands, 32 lanes per loop iteration.
unicast_impl_wrap!(
    f16,
    arm64fp16_unicast_subf_f16_32n,
    32,
    8,
    #[inline(never)]
    fn run(a: &mut [f16], b: &[f16]) {
        // The asm loop consumes exactly 32 lanes per iteration and only tests the
        // counter after decrementing it, so it requires a non-empty multiple of 32.
        assert!(a.len() == b.len());
        assert!(a.len() % 32 == 0);
        assert!(a.len() > 0);
        #[target_feature(enable = "fp16")]
        unsafe fn run(a: &mut [f16], b: &[f16]) {
            unsafe {
                let len = a.len();
                // The asm stores through this pointer, so it must carry write
                // provenance: take it from `as_mut_ptr`, not `as_ptr`.
                let a_ptr = a.as_mut_ptr();
                let b_ptr = b.as_ptr();
                std::arch::asm!("
                2:
                    ld1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}]
                    ld1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{b_ptr}], 64
                    fsub v0.8h, v4.8h, v0.8h
                    fsub v1.8h, v5.8h, v1.8h
                    fsub v2.8h, v6.8h, v2.8h
                    fsub v3.8h, v7.8h, v3.8h
                    st1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}], 64
                    subs {len}, {len}, 32
                    bne 2b
            ",
            len = inout(reg) len => _,
            a_ptr = inout(reg) a_ptr => _,
            b_ptr = inout(reg) b_ptr => _,
            // v0-v3 receive `a`, v4-v7 receive `b`: all eight are overwritten.
            out("v0") _, out("v1") _, out("v2") _, out("v3") _,
            out("v4") _, out("v5") _, out("v6") _, out("v7") _,
            options(nostack));
            }
        }
        unsafe { run(a, b) }
    }
);

// Element-wise `a[i] = min(a[i], b[i])` over f16 slices, 32 lanes per loop
// iteration. Uses `fminnm` (IEEE minNum: a quiet NaN operand yields the other
// operand) so the result agrees with the scalar reference `a.min(b)`; plain
// `fmin` would propagate NaN instead.
unicast_impl_wrap!(
    f16,
    arm64fp16_unicast_min_f16_32n,
    32,
    8,
    #[inline(never)]
    fn run(a: &mut [f16], b: &[f16]) {
        // The asm loop consumes exactly 32 lanes per iteration and only tests the
        // counter after decrementing it, so it requires a non-empty multiple of 32.
        assert!(a.len() == b.len());
        assert!(a.len() % 32 == 0);
        assert!(a.len() > 0);
        #[target_feature(enable = "fp16")]
        unsafe fn run(a: &mut [f16], b: &[f16]) {
            unsafe {
                let len = a.len();
                // The asm stores through this pointer, so it must carry write
                // provenance: take it from `as_mut_ptr`, not `as_ptr`.
                let a_ptr = a.as_mut_ptr();
                let b_ptr = b.as_ptr();
                std::arch::asm!("
                2:
                    ld1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}]
                    ld1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{b_ptr}], 64
                    fminnm v0.8h, v0.8h, v4.8h
                    fminnm v1.8h, v1.8h, v5.8h
                    fminnm v2.8h, v2.8h, v6.8h
                    fminnm v3.8h, v3.8h, v7.8h
                    st1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}], 64
                    subs {len}, {len}, 32
                    bne 2b
            ",
            len = inout(reg) len => _,
            a_ptr = inout(reg) a_ptr => _,
            b_ptr = inout(reg) b_ptr => _,
            // v0-v3 receive `a`, v4-v7 receive `b`: all eight are overwritten.
            out("v0") _, out("v1") _, out("v2") _, out("v3") _,
            out("v4") _, out("v5") _, out("v6") _, out("v7") _,
            options(nostack));
            }
        }
        unsafe { run(a, b) }
    }
);

// Element-wise `a[i] = max(a[i], b[i])` over f16 slices, 32 lanes per loop
// iteration. Uses `fmaxnm` (IEEE maxNum: a quiet NaN operand yields the other
// operand) so the result agrees with the scalar reference `a.max(b)`; plain
// `fmax` would propagate NaN instead.
unicast_impl_wrap!(
    f16,
    arm64fp16_unicast_max_f16_32n,
    32,
    8,
    #[inline(never)]
    fn run(a: &mut [f16], b: &[f16]) {
        // The asm loop consumes exactly 32 lanes per iteration and only tests the
        // counter after decrementing it, so it requires a non-empty multiple of 32.
        assert!(a.len() == b.len());
        assert!(a.len() % 32 == 0);
        assert!(a.len() > 0);
        #[target_feature(enable = "fp16")]
        unsafe fn run(a: &mut [f16], b: &[f16]) {
            unsafe {
                let len = a.len();
                // The asm stores through this pointer, so it must carry write
                // provenance: take it from `as_mut_ptr`, not `as_ptr`.
                let a_ptr = a.as_mut_ptr();
                let b_ptr = b.as_ptr();
                std::arch::asm!("
                2:
                    ld1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}]
                    ld1 {{v4.8h, v5.8h, v6.8h, v7.8h}}, [{b_ptr}], 64
                    fmaxnm v0.8h, v0.8h, v4.8h
                    fmaxnm v1.8h, v1.8h, v5.8h
                    fmaxnm v2.8h, v2.8h, v6.8h
                    fmaxnm v3.8h, v3.8h, v7.8h
                    st1 {{v0.8h, v1.8h, v2.8h, v3.8h}}, [{a_ptr}], 64
                    subs {len}, {len}, 32
                    bne 2b
            ",
            len = inout(reg) len => _,
            a_ptr = inout(reg) a_ptr => _,
            b_ptr = inout(reg) b_ptr => _,
            // v0-v3 receive `a`, v4-v7 receive `b`: all eight are overwritten.
            out("v0") _, out("v1") _, out("v2") _, out("v3") _,
            out("v4") _, out("v5") _, out("v6") _, out("v7") _,
            options(nostack));
            }
        }
        unsafe { run(a, b) }
    }
);

#[cfg(test)]
mod test_arm64fp16_unicast_mul_f16_32n {
    // Despite the module name, this covers every fp16 unicast kernel in this file.
    // Each kernel is compared against a scalar reference closure. The first macro
    // argument, `crate::arm64::has_fp16()`, presumably gates the tests on CPU
    // FP16 support (confirm against the `unicast_frame_tests!` definition).
    use proptest::strategy::Strategy;
    use super::*;

    // a[i] * b[i]
    crate::unicast_frame_tests!(
        crate::arm64::has_fp16(),
        f16,
        arm64fp16_unicast_mul_f16_32n,
        |lhs, rhs| lhs * rhs
    );

    // a[i] + b[i]
    crate::unicast_frame_tests!(
        crate::arm64::has_fp16(),
        f16,
        arm64fp16_unicast_add_f16_32n,
        |lhs, rhs| lhs + rhs
    );

    // a[i] - b[i]
    crate::unicast_frame_tests!(
        crate::arm64::has_fp16(),
        f16,
        arm64fp16_unicast_sub_f16_32n,
        |lhs, rhs| lhs - rhs
    );

    // Reversed operands: b[i] - a[i]
    crate::unicast_frame_tests!(
        crate::arm64::has_fp16(),
        f16,
        arm64fp16_unicast_subf_f16_32n,
        |lhs, rhs| rhs - lhs
    );

    // min(a[i], b[i])
    crate::unicast_frame_tests!(
        crate::arm64::has_fp16(),
        f16,
        arm64fp16_unicast_min_f16_32n,
        |lhs, rhs| lhs.min(rhs)
    );

    // max(a[i], b[i])
    crate::unicast_frame_tests!(
        crate::arm64::has_fp16(),
        f16,
        arm64fp16_unicast_max_f16_32n,
        |lhs, rhs| lhs.max(rhs)
    );
}
